{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f0f701f-c552-4ca1-8188-2cdfc1362f6b",
   "metadata": {},
   "source": [
    "# Uni-Mol Molecular Representation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3449ed8-2a57-4e62-9163-e32baf66e828",
   "metadata": {},
   "source": [
    "**Licenses**\n",
    "\n",
    "Copyright (c) DP Technology.\n",
    "\n",
    "This source code is licensed under the MIT license found in the\n",
    "LICENSE file in the root directory of this source tree.\n",
    "\n",
    "**Citations**\n",
    "\n",
    "Please cite the following papers if you use this notebook:\n",
    "\n",
    "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n",
    "ChemRxiv (2022)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d51f850-76cd-4801-bf2e-a4c53221d586",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import lmdb\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import glob\n",
    "from multiprocessing import Pool\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89c70ab0-da59-459d-bf1c-ac307e9e7ae5",
   "metadata": {},
   "source": [
    "### Your SMILES list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa0ce2a-b7aa-4cae-81ba-27b91c0591e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example input: the SMILES strings to embed. Replace with your own molecules.\n",
    "# A backslash inside a SMILES string must be written as \\\\ in a normal Python string literal.\n",
    "smi_list = [\n",
    "'CC1=C(C(=O)OC2CCCC2)[C@H](c2ccccc2OC(C)C)C2=C(O)CC(C)(C)CC2=[N+]1',\n",
    "'COc1cccc(-c2nc(C(=O)NC[C@H]3CCCO3)cc3c2[nH]c2ccccc23)c1',\n",
    "'O=C1c2ccccc2C(=O)c2c1ccc(C(=O)n1nc3c4c(cccc41)C(=O)c1ccccc1-3)c2[N+](=O)[O-]',\n",
    "'COc1cc(/C=N/c2nonc2NC(C)=O)ccc1OC(C)C',\n",
    "'CCC[C@@H]1CN(Cc2ccc3nsnc3c2)C[C@H]1NS(C)(=O)=O',\n",
    "'CCc1nnc(N/C(O)=C/CCOc2ccc(OC)cc2)s1',\n",
    "'CC(C)(C)SCCN/C=C1\\\\C(=O)NC(=O)N(c2ccc(Br)cc2)C1=O',\n",
    "'CC(C)(C)c1nc(COc2ccc3c(c2)CCn2c-3cc(OCC3COCCO3)nc2=O)no1',\n",
    "'N#CCCNS(=O)(=O)c1ccc(/C(O)=N/c2ccccc2Oc2ccccc2Cl)cc1',\n",
    "'O=C(Nc1ncc(Cl)s1)c1cccc(S(=O)(=O)Nc2ccc(Br)cc2)c1',\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b109d84a-8d59-445b-9997-d1383ee24079",
   "metadata": {},
   "source": [
    "### Generate conformations from SMILES and save to .lmdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea582d7d-8851-4d46-880e-54867737b232",
   "metadata": {},
   "outputs": [],
   "source": [
    "def smi2_2Dcoords(smi):\n",
    "    \"\"\"Return RDKit 2D layout coordinates for a SMILES string.\n",
    "\n",
    "    Hydrogens are added before the layout is computed, so the result has one\n",
    "    row per atom of the hydrogen-complete molecule.\n",
    "\n",
    "    Args:\n",
    "        smi: SMILES string of the molecule.\n",
    "\n",
    "    Returns:\n",
    "        float32 array of conformer positions, one row per atom.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the number of coordinate rows differs from the atom count.\n",
    "    \"\"\"\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    AllChem.Compute2DCoords(mol)\n",
    "    coordinates = mol.GetConformer().GetPositions().astype(np.float32)\n",
    "    # Sanity check: every atom (including the added hydrogens) must have a coordinate row.\n",
    "    assert len(mol.GetAtoms()) == len(coordinates), \"2D coordinates shape is not align with {}\".format(smi)\n",
    "    return coordinates\n",
    "\n",
    "\n",
    "def smi2_3Dcoords(smi,cnt):\n",
    "    \"\"\"Generate `cnt` 3D conformers for a SMILES string, one per random seed.\n",
    "\n",
    "    For each seed in ``range(cnt)`` the hydrogen-complete molecule is embedded\n",
    "    with RDKit and then MMFF-optimized. If the embedding reports failure (-1),\n",
    "    a second attempt embeds the molecule without explicit hydrogens using more\n",
    "    attempts and adds the hydrogens (with coordinates) afterwards. If any step\n",
    "    raises, the 2D layout from ``smi2_2Dcoords`` is used for that seed instead.\n",
    "\n",
    "    Args:\n",
    "        smi: SMILES string of the molecule.\n",
    "        cnt: number of conformers to generate; seeds 0..cnt-1 are used.\n",
    "\n",
    "    Returns:\n",
    "        List of `cnt` float32 coordinate arrays, one row per atom (hydrogens included).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if a coordinate array does not have one row per atom.\n",
    "        Exception: whatever ``smi2_2Dcoords`` raises if the 2D fallback itself fails.\n",
    "    \"\"\"\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    coordinate_list = []\n",
    "    for seed in range(cnt):\n",
    "        try:\n",
    "            # A fixed non-negative seed makes the conformer reproducible (-1 would be random).\n",
    "            res = AllChem.EmbedMolecule(mol, randomSeed=seed)\n",
    "            if res == -1:\n",
    "                # Embedding failed: retry without explicit hydrogens and with more attempts,\n",
    "                # then add hydrogens with coordinates derived from the heavy atoms.\n",
    "                embedded = Chem.MolFromSmiles(smi)\n",
    "                AllChem.EmbedMolecule(embedded, maxAttempts=5000, randomSeed=seed)\n",
    "                embedded = AllChem.AddHs(embedded, addCoords=True)\n",
    "            else:\n",
    "                embedded = mol\n",
    "            AllChem.MMFFOptimizeMolecule(embedded)  # some conformers cannot be MMFF-optimized\n",
    "            coordinates = embedded.GetConformer().GetPositions()\n",
    "        except Exception:\n",
    "            # Catch Exception rather than everything, so Ctrl-C / SystemExit still stop the run.\n",
    "            print(\"Failed to generate 3D, replace with 2D\")\n",
    "            coordinates = smi2_2Dcoords(smi)\n",
    "\n",
    "        assert len(mol.GetAtoms()) == len(coordinates), \"3D coordinates shape is not align with {}\".format(smi)\n",
    "        coordinate_list.append(coordinates.astype(np.float32))\n",
    "    return coordinate_list\n",
    "\n",
    "\n",
    "def inner_smi2coords(content, cnt=10, max_atoms=400):\n",
    "    \"\"\"Build the pickled record (atoms, conformer coordinates, SMILES) for one molecule.\n",
    "\n",
    "    The record holds ``cnt`` 3D conformers followed by one 2D layout, i.e.\n",
    "    ``cnt + 1`` coordinate arrays. Molecules with more than ``max_atoms`` atoms\n",
    "    (counted before hydrogens are added) skip 3D generation and store the 2D\n",
    "    layout ``cnt + 1`` times instead.\n",
    "\n",
    "    Args:\n",
    "        content: SMILES string of the molecule.\n",
    "        cnt: number of 3D conformers to generate (default 10, so 11 arrays in total).\n",
    "        max_atoms: atom-count threshold above which only 2D coordinates are used.\n",
    "\n",
    "    Returns:\n",
    "        bytes: pickle of ``{'atoms': [...], 'coordinates': [...], 'smi': smi}``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if RDKit cannot parse the SMILES string.\n",
    "    \"\"\"\n",
    "    smi = content\n",
    "\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    if mol is None:\n",
    "        # MolFromSmiles signals a parse failure by returning None; fail with a clear message.\n",
    "        raise ValueError(\"invalid SMILES: {}\".format(smi))\n",
    "    if len(mol.GetAtoms()) > max_atoms:\n",
    "        coordinate_list =  [smi2_2Dcoords(smi)] * (cnt+1)\n",
    "        print(\"atom num >{},use 2D coords\".format(max_atoms),smi)\n",
    "    else:\n",
    "        coordinate_list = smi2_3Dcoords(smi,cnt)\n",
    "        # append the 2D layout as the final conformer\n",
    "        coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))\n",
    "    mol = AllChem.AddHs(mol)\n",
    "    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]  # symbols after adding H, matching the coordinate rows\n",
    "    return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi }, protocol=-1)\n",
    "\n",
    "\n",
    "def smi2coords(content):\n",
    "    \"\"\"Best-effort wrapper around ``inner_smi2coords`` for use with ``Pool.imap``.\n",
    "\n",
    "    Args:\n",
    "        content: SMILES string of the molecule.\n",
    "\n",
    "    Returns:\n",
    "        The pickled record from ``inner_smi2coords``, or None if processing failed\n",
    "        (the failing SMILES is printed).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return inner_smi2coords(content)\n",
    "    except Exception:\n",
    "        # `content` is the SMILES string itself; indexing it would print only its first character.\n",
    "        print(\"failed smiles: {}\".format(content))\n",
    "        return None\n",
    "\n",
    "\n",
    "def write_lmdb(smiles_list, job_name, seed=42, outpath='./results', nthreads=8):\n",
    "    \"\"\"Generate conformers for every SMILES in parallel and store them in an LMDB file.\n",
    "\n",
    "    Records are written under consecutive ASCII integer keys (\"0\", \"1\", ...).\n",
    "    SMILES for which ``smi2coords`` returns None are skipped, so keys stay\n",
    "    contiguous but do not necessarily match positions in ``smiles_list``.\n",
    "\n",
    "    Args:\n",
    "        smiles_list: iterable of SMILES strings.\n",
    "        job_name: base name of the output file, ``<outpath>/<job_name>.lmdb``.\n",
    "        seed: accepted for interface compatibility; not used by this function.\n",
    "        outpath: output directory, created if missing.\n",
    "        nthreads: number of worker processes in the pool.\n",
    "    \"\"\"\n",
    "    os.makedirs(outpath, exist_ok=True)\n",
    "    output_name = os.path.join(outpath,'{}.lmdb'.format(job_name))\n",
    "    # Start from a fresh file; only a missing file is an expected, ignorable condition.\n",
    "    try:\n",
    "        os.remove(output_name)\n",
    "    except FileNotFoundError:\n",
    "        pass\n",
    "    env_new = lmdb.open(\n",
    "        output_name,\n",
    "        subdir=False,\n",
    "        readonly=False,\n",
    "        lock=False,\n",
    "        readahead=False,\n",
    "        meminit=False,\n",
    "        max_readers=1,\n",
    "        map_size=int(100e9),\n",
    "    )\n",
    "    try:\n",
    "        # The transaction context manager commits on success and aborts on error.\n",
    "        with env_new.begin(write=True) as txn_write, Pool(nthreads) as pool:\n",
    "            i = 0\n",
    "            for inner_output in tqdm(pool.imap(smi2coords, smiles_list)):\n",
    "                if inner_output is not None:\n",
    "                    txn_write.put(f'{i}'.encode(\"ascii\"), inner_output)\n",
    "                    i += 1\n",
    "            print('{} process {} lines'.format(job_name, i))\n",
    "    finally:\n",
    "        # Always release the environment, even if a worker or a put() raised.\n",
    "        env_new.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dad25a1a-f93e-4fdf-b389-2a3fe61a40ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the conformer-generation and inference cells.\n",
    "job_name = 'get_mol_repr'  # name of this run; also the .lmdb file name\n",
    "seed = 42\n",
    "\n",
    "# Paths\n",
    "data_path = './results'  # directory that receives the generated .lmdb\n",
    "results_path = data_path  # directory for inference outputs; change to save elsewhere\n",
    "weight_path = '../ckp/mol_pre_no_h_220816.pt'  # pretrained checkpoint to load\n",
    "dict_name = 'dict.txt'\n",
    "\n",
    "# Inference options\n",
    "only_polar = 0  # no h\n",
    "batch_size = 16\n",
    "conf_size = 11  # conformers per molecule: 10 3D + 1 2D by default\n",
    "\n",
    "write_lmdb(\n",
    "    smi_list,\n",
    "    job_name=job_name,\n",
    "    seed=seed,\n",
    "    outpath=data_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12284210-7f86-4062-b291-7c077ef6f83a",
   "metadata": {},
   "source": [
    "### Run inference from the pretrained checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fb2391b-81b0-4b11-95ea-3b7855db9bc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: Inference currently supports a single GPU only; the command below selects it with CUDA_VISIBLE_DEVICES=\"0\".\n",
    "# The first command copies the dictionary file into the data directory; the second runs infer.py on the\n",
    "# .lmdb written earlier (--valid-subset is the job name) and writes its output to $results_path.\n",
    "!cp ../example_data/molecule/$dict_name $data_path\n",
    "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol $data_path --valid-subset $job_name \\\n",
    "       --results-path $results_path \\\n",
    "       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
    "       --task unimol --loss unimol_infer --arch unimol_base \\\n",
    "       --path $weight_path \\\n",
    "       --only-polar $only_polar --dict-name $dict_name --conf-size $conf_size \\\n",
    "       --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8421258-eca6-4801-aadd-fc67fd928cb1",
   "metadata": {},
   "source": [
    "### Read .pkl and save results to .csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c456f31e-94fc-4593-97c9-1db7182465aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_csv_results(predict_path, results_path):\n",
    "    \"\"\"Average per-conformer representations for each molecule and save them as CSV.\n",
    "\n",
    "    Args:\n",
    "        predict_path: path to the inference pickle. It is iterated as a sequence of\n",
    "            batch dicts with keys \"bsz\", \"data_name\", \"mol_repr_cls\", \"atom_repr\"\n",
    "            and \"pair_repr\".\n",
    "        results_path: directory in which ``mol_repr.csv`` is written.\n",
    "\n",
    "    Returns:\n",
    "        None. Writes ``<results_path>/mol_repr.csv`` and prints a short summary.\n",
    "    \"\"\"\n",
    "    # NOTE: unpickling can execute arbitrary code; only load files produced by your own inference run.\n",
    "    predict = pd.read_pickle(predict_path)\n",
    "    mol_repr_dict = defaultdict(list)\n",
    "    atom_repr_dict = defaultdict(list)\n",
    "    pair_repr_dict = defaultdict(list)\n",
    "    # Group every entry by its \"data_name\" value (used as the SMILES column below).\n",
    "    for batch in predict:\n",
    "        sz = batch[\"bsz\"]\n",
    "        for i in range(sz):\n",
    "            smi = batch[\"data_name\"][i]\n",
    "            mol_repr_dict[smi].append(batch[\"mol_repr_cls\"][i])\n",
    "            atom_repr_dict[smi].append(batch[\"atom_repr\"][i])\n",
    "            pair_repr_dict[smi].append(batch[\"pair_repr\"][i])\n",
    "    # get mean repr for each molecule with multiple conf\n",
    "    smi_list, avg_mol_repr_list, avg_atom_repr_list, avg_pair_repr_list = [], [], [], []\n",
    "    for smi in mol_repr_dict.keys():\n",
    "        smi_list.append(smi)\n",
    "        avg_mol_repr_list.append(np.mean(mol_repr_dict[smi], axis=0))\n",
    "        avg_atom_repr_list.append(np.mean(atom_repr_dict[smi], axis=0))\n",
    "        avg_pair_repr_list.append(np.mean(pair_repr_dict[smi], axis=0))\n",
    "    predict_df = pd.DataFrame({\n",
    "        \"SMILES\": smi_list,\n",
    "        \"mol_repr\": avg_mol_repr_list,\n",
    "        \"atom_repr\": avg_atom_repr_list,\n",
    "        \"pair_repr\": avg_pair_repr_list\n",
    "    })\n",
    "    # DataFrame.info() prints its own report and returns None, so it is called\n",
    "    # on its own instead of being passed to print() (which would print \"None\").\n",
    "    print(predict_df.head(1))\n",
    "    predict_df.info()\n",
    "    # NOTE(review): array-valued cells are written in their string form; large arrays may be\n",
    "    # abbreviated with \"...\" by NumPy's print threshold - confirm before relying on the CSV.\n",
    "    predict_df.to_csv(os.path.join(results_path, 'mol_repr.csv'),index=False)\n",
    "\n",
    "# Locate the inference output; fail with a clear message instead of a bare IndexError.\n",
    "pkl_candidates = sorted(glob.glob(f'{results_path}/*_{job_name}.out.pkl'))\n",
    "if not pkl_candidates:\n",
    "    raise FileNotFoundError(f'no *_{job_name}.out.pkl file found in {results_path}; run the inference cell first')\n",
    "pkl_path = pkl_candidates[0]  # sorted, so the choice is deterministic when several files match\n",
    "get_csv_results(pkl_path, results_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
